// Copyright (C) Kumo inc. and its affiliates.
// Author: Jeff.li lijippy@163.com
// All rights reserved.
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <https://www.gnu.org/licenses/>.
//

#include <algorithm>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include <benchmark/benchmark.h>

#include <nebula/core/memory_pool.h>
#include <nebula/types/type_fwd.h>
#include <nebula/util/cpu_info.h>
#include <turbo/log/logging.h>  // IWYU pragma: keep

namespace nebula {

    // Benchmark changed its parameter type between releases from
    // int to int64_t. As it doesn't have version macros, we need
    // to apply C++ template magic.

    // Primary template: intentionally left undefined. Only the partial
    // specialization below (which matches the pointer-to-member-function type of
    // benchmark::internal::Benchmark::Args) provides a definition, so using any
    // other Func is a compile-time error.
    template<typename Func>
    struct BenchmarkArgsType;

    // Pattern matching that extracts the vector element type of Benchmark::Args()
    template<typename Values>
    struct BenchmarkArgsType<benchmark::internal::Benchmark *(
    benchmark::internal::Benchmark::*)(const std::vector<Values> &)> {
        using type = Values;
    };

    // The element type Benchmark::Args() accepts: `int` on old Google Benchmark
    // releases, `int64_t` on newer ones. Deduced from the signature so this file
    // compiles against either version without version macros.
    using ArgsType =
            typename BenchmarkArgsType<decltype(&benchmark::internal::Benchmark::Args)>::type;

    using internal::CpuInfo;

    // Singleton CPU info used to size benchmark working sets to the cache
    // hierarchy of the machine the benchmark runs on.
    static const CpuInfo *cpu_info = CpuInfo::GetInstance();

    // Per-level data cache sizes in bytes, as reported by CpuInfo.
    // NOTE(review): these are namespace-scope statics initialized from
    // cpu_info above; they rely on top-to-bottom initialization order within
    // this translation unit.
    static const int64_t kL1Size = cpu_info->CacheSize(CpuInfo::CacheLevel::L1);
    static const int64_t kL2Size = cpu_info->CacheSize(CpuInfo::CacheLevel::L2);
    static const int64_t kL3Size = cpu_info->CacheSize(CpuInfo::CacheLevel::L3);
    // A working set 4x larger than L3, guaranteed to spill to main memory.
    static const int64_t kCantFitInL3Size = kL3Size * 4;
    // Standard ladder of benchmark sizes: fits-in-L1/L2/L3, then memory-bound.
    static const std::vector<int64_t> kMemorySizes = {kL1Size, kL2Size, kL3Size,
                                                      kCantFitInL3Size};
    // Inverse null proportions: a value of N means 1/N of items are null.
    // 0 is treated as "no nulls"
    static const std::vector<ArgsType> kInverseNullProportions = {10000, 100, 10, 2, 1, 0};

    // RAII helper for item-count benchmarks: the constructor decodes the
    // benchmark arguments (range(0) = item count, range(1) = inverse null
    // proportion, 0 meaning "no nulls"), and the destructor publishes the
    // "size" and "null_percent" counters plus the items-processed total.
    struct GenericItemsArgs {
        // number of items processed per iteration
        const int64_t size;

        // proportion of nulls in generated arrays
        double null_proportion;

        explicit GenericItemsArgs(benchmark::State &state)
                : size(state.range(0)), state_(state) {
            const auto inverse_proportion = state.range(1);
            // 0 encodes "no nulls"; otherwise clamp 1/N to at most 1.0.
            null_proportion =
                    (inverse_proportion == 0)
                            ? 0.0
                            : std::min(1., 1. / static_cast<double>(inverse_proportion));
        }

        ~GenericItemsArgs() {
            state_.counters["size"] = static_cast<double>(size);
            state_.counters["null_percent"] = null_proportion * 100;
            state_.SetItemsProcessed(state_.iterations() * size);
        }

    private:
        benchmark::State &state_;
    };

    // Register the cross product of `sizes` x kInverseNullProportions as
    // benchmark arguments, reporting times in microseconds.
    void BenchmarkSetArgsWithSizes(benchmark::internal::Benchmark *bench,
                                   const std::vector<int64_t> &sizes = kMemorySizes) {
        bench->Unit(benchmark::kMicrosecond);

        for (const int64_t item_count: sizes) {
            // Cast once per size; ArgsType may be int or int64_t depending on
            // the Google Benchmark version.
            const auto casted_count = static_cast<ArgsType>(item_count);
            for (const auto inverse_proportion: kInverseNullProportions) {
                bench->Args({casted_count, inverse_proportion});
            }
        }
    }

    // Register the default argument matrix (full cache-size ladder).
    void BenchmarkSetArgs(benchmark::internal::Benchmark *bench) {
        BenchmarkSetArgsWithSizes(bench);  // defaults to kMemorySizes
    }

    // Register arguments for regression benchmarks.
    void RegressionSetArgs(benchmark::internal::Benchmark *bench) {
        // Regression do not need to account for cache hierarchy, thus optimize for
        // the best case.
        const std::vector<int64_t> best_case_sizes = {kL1Size};
        BenchmarkSetArgsWithSizes(bench, best_case_sizes);
    }

    // RAII struct to handle some of the boilerplate in regression benchmarks
    struct RegressionArgs {
        // size of memory tested (per iteration) in bytes
        int64_t size;

        // proportion of nulls in generated arrays
        double null_proportion;

        // If size_is_bytes is true, then it's a number of bytes, otherwise it's the
        // number of items processed (for reporting)
        explicit RegressionArgs(benchmark::State &state, bool size_is_bytes = true)
                : size(state.range(0)), state_(state), size_is_bytes_(size_is_bytes) {
            if (state.range(1) == 0) {
                this->null_proportion = 0.0;
            } else {
                this->null_proportion = std::min(1., 1. / static_cast<double>(state.range(1)));
            }
        }

        ~RegressionArgs() {
            state_.counters["size"] = static_cast<double>(size);
            state_.counters["null_percent"] = null_proportion * 100;
            if (size_is_bytes_) {
                state_.SetBytesProcessed(state_.iterations() * size);
            } else {
                state_.SetItemsProcessed(state_.iterations() * size);
            }
        }

    private:
        benchmark::State &state_;
        bool size_is_bytes_;
    };

    // Google Benchmark MemoryManager implementation backed by a ProxyMemoryPool.
    //
    // Start() installs a fresh proxy pool (exposed via the public `memory_pool`
    // member) and snapshots the default pool's allocation count; Stop() diffs
    // that snapshot against the proxy pool's stats and records max bytes used,
    // total bytes allocated and allocation count into the benchmark result.
    class MemoryPoolMemoryManager : public benchmark::MemoryManager {
        void Start() override {
            memory_pool = std::make_shared<ProxyMemoryPool>(default_memory_pool());

            MemoryPool *default_pool = default_memory_pool();
            global_allocations_start = default_pool->num_allocations();
        }

        // BENCHMARK_DONT_OPTIMIZE is used here to detect Google Benchmark
        // 1.8.0. We can remove this Stop(turbo::Result*) when we require Google
        // Benchmark 1.8.0 or later.
#ifndef BENCHMARK_DONT_OPTIMIZE

        void Stop(turbo::Result *result) override { Stop(*result); }

#endif

        void Stop(benchmark::MemoryManager::Result &result) override {
            // If num_allocations is still zero, we assume that the memory pool wasn't passed down
            // so we should record them.
            MemoryPool *default_pool = default_memory_pool();
            int64_t new_default_allocations =
                    default_pool->num_allocations() - global_allocations_start;

            // Only record metrics if (1) there were allocations and (2) we
            // recorded at least one.
            if (new_default_allocations > 0 && memory_pool->num_allocations() > 0) {
                if (new_default_allocations > memory_pool->num_allocations()) {
                    // If we missed some, let's report that.
                    int64_t missed_allocations =
                            new_default_allocations - memory_pool->num_allocations();
                    KLOG(WARNING) << "BenchmarkMemoryTracker recorded some allocations "
                                        << "for a benchmark, but missed " << missed_allocations
                                        << " allocations.\n";
                }

                result.max_bytes_used = memory_pool->max_memory();
                result.total_allocated_bytes = memory_pool->total_bytes_allocated();
                result.num_allocs = memory_pool->num_allocations();
            }
        }

    public:
        // Proxy pool for the benchmark currently being measured; replaced on
        // every Start(). Null until the first Start().
        std::shared_ptr<::nebula::ProxyMemoryPool> memory_pool;

    protected:
        // Default pool allocation count snapshotted at Start(). Brace-initialized
        // so a Stop() without a preceding Start() does not read an indeterminate
        // value (which would be undefined behavior).
        int64_t global_allocations_start{0};
    };

    /// \brief Track memory pool allocations in benchmarks.
    ///
    /// Instantiate as a global variable to register the hooks into Google Benchmark
    /// to collect memory metrics. Before each benchmark, a new ProxyMemoryPool is
    /// created. It can then be accessed with memory_pool(). Once the benchmark is
    /// complete, the hook will record the maximum memory used, the total bytes
    /// allocated, and the total number of allocations. If no allocations were seen,
    /// (for example, if you forgot to pass down the memory pool), then these metrics
    /// will not be saved.
    ///
    /// Since this is used as one global variable, this will not work if multiple
    /// benchmarks are run concurrently or for multi-threaded benchmarks (ones
    /// that use `->ThreadRange(...)`).
    class BenchmarkMemoryTracker {
    public:
        // Register the owned manager with Google Benchmark at construction time.
        BenchmarkMemoryTracker() { ::benchmark::RegisterMemoryManager(&manager_); }

        // Proxy pool of the benchmark currently being measured (null before the
        // first benchmark starts).
        ::nebula::MemoryPool *memory_pool() const {
            return manager_.memory_pool.get();
        }

    protected:
        ::nebula::MemoryPoolMemoryManager manager_;
    };

}  // namespace nebula
